from lib.base import BaseLoss
import torch
from torch.nn import functional as F


class Experiment1Loss(BaseLoss):
    """Loss that picks its criterion from the phase tag given to ``forward``.

    ``episodic_loss`` and ``cross_entropy`` are not defined in this class;
    presumably they are inherited from ``BaseLoss`` -- confirm there.
    """

    def __init__(self, cfg):
        """Pass ``cfg`` on to ``BaseLoss`` and keep a reference on the instance."""
        super().__init__(cfg)
        self.cfg = cfg

    def forward(self, logits, labels, type):
        """Return the loss for ``logits`` against ``labels``.

        When ``type`` equals ``'TEST'`` the result comes from
        ``self.episodic_loss(logits, labels, type)``; every other value
        falls through to ``self.cross_entropy(logits, labels)``.
        """
        # NOTE: ``type`` shadows the builtin, but it is part of the call
        # signature, so the name is kept for callers passing it by keyword.
        if type == 'TEST':
            return self.episodic_loss(logits, labels, type)
        return self.cross_entropy(logits, labels)
